import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics


def preprocess(xx, yy):
    """Convert one (image, label) pair to the dtypes used for training.

    The image is cast to float32 and divided by 255; the label is cast to
    int32. No reshaping or one-hot encoding happens here.

    Args:
        xx: Image tensor.
        yy: Label tensor.

    Returns:
        A ``(scaled_image, int_label)`` tuple.
    """
    scaled_image = tf.cast(xx, dtype=tf.float32) / 255.
    int_label = tf.cast(yy, dtype=tf.int32)
    return scaled_image, int_label


# Load the Fashion-MNIST train/test splits and report their array shapes.
(x, y), (x_test, y_test) = datasets.fashion_mnist.load_data()
print('x, y: ', x.shape, y.shape)
print('x_test, y_test: ', x_test.shape, y_test.shape)

batch_size = 128

# Training pipeline: preprocess each sample, shuffle through a
# 10000-element buffer, then group into batches.
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess)
db = db.shuffle(10000)
db = db.batch(batch_size=batch_size)
print(db)

# Test pipeline: same preprocessing and batching, but no shuffling.
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.map(preprocess)
db_test = db_test.batch(batch_size=batch_size)
print(db_test)

# Define the layers: four ReLU hidden layers (256, 128, 64, 32 units)
# followed by a 10-unit output layer with no activation.
model = Sequential(
    [layers.Dense(width, activation=tf.nn.relu) for width in (256, 128, 64, 32)]
    + [layers.Dense(10)]
)

# Initialize the weights for flattened 28 * 28 = 784 inputs, then print
# the layer/parameter summary.
model.build(input_shape=[None, 28 * 28])
model.summary()

# Optimizer. Use the `learning_rate` keyword: `lr` is a deprecated alias
# that newer Keras releases no longer accept.
optimizer = optimizers.Adam(learning_rate=1e-3)


def _train_one_epoch(epoch):
    """Run one pass over ``db``, updating ``model`` through ``optimizer``.

    The optimized objective is the softmax cross-entropy between the model's
    logits and the one-hot labels. Every 100 steps the cross-entropy and an
    MSE diagnostic (logits vs. one-hot labels) are printed.

    Args:
        epoch: Epoch index, used only in the progress printout.
    """
    for step, (images, labels) in enumerate(db):
        # Flatten each image to the 784-wide vector the model was built for.
        images = tf.reshape(images, [-1, 784])
        labels_onehot = tf.one_hot(labels, depth=10)

        with tf.GradientTape() as tape:
            logits = model(images)
            loss_entropy = tf.reduce_mean(
                tf.losses.categorical_crossentropy(
                    labels_onehot, logits, from_logits=True
                )
            )
        gradients = tape.gradient(loss_entropy, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        if step % 100 == 0:
            # The MSE is only reported, never optimized, so it is computed
            # outside the tape and only on the steps that print it.
            loss_mse = tf.reduce_mean(tf.losses.MSE(labels_onehot, logits))
            print(epoch, step, ', entropy loss: ', float(loss_entropy), ', mse loss: ', float(loss_mse))


def _evaluate():
    """Return the fraction of ``db_test`` samples the model classifies correctly.

    Raises:
        ZeroDivisionError: If ``db_test`` yields no samples.
    """
    total_correct = 0
    total_num = 0
    for images, labels in db_test:
        images = tf.reshape(images, [-1, 784])
        logits = model(images)

        # Softmax is monotonic, so the argmax of the raw logits is the
        # predicted class; int32 output matches the labels' dtype
        # (set by preprocess).
        prediction = tf.argmax(logits, axis=1, output_type=tf.int32)
        matches = tf.cast(tf.equal(prediction, labels), tf.int32)

        total_correct += int(tf.reduce_sum(matches))
        total_num += images.shape[0]

    return total_correct / total_num


def main(epochs=10):
    """Train the model, printing test accuracy after every epoch.

    Args:
        epochs: Number of full passes over the training dataset.
    """
    for epoch in range(epochs):
        _train_one_epoch(epoch)
        accuracy = _evaluate()
        print(epoch, "accuracy: ", accuracy)


# Start training only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()
